
import numpy as np
import mmwave.dsp as dsp
from mmwave.dataloader import DCA1000
import matplotlib.pyplot as plt

# Close every currently open matplotlib figure. This is a module-level
# statement, so it runs as soon as this file is executed or imported.
plt.close("all")

def NMF(mag_spectrum, nFeatures, diff_threshold):
    """Factorize a non-negative matrix as ``mag_spectrum ~= W_NMF @ H_NMF``.

    Uses the multiplicative update rules of Lee and Seung for the Frobenius
    (Euclidean) objective. ``W_NMF`` and ``H_NMF`` are initialised with
    uniform random values from ``np.random.rand`` (W first, then H), so seed
    NumPy's global RNG beforehand if reproducible results are needed.

    Iteration stops once the absolute change of the reconstruction error
    ``||mag_spectrum - W_NMF @ H_NMF||`` between two consecutive iterations
    is no longer greater than ``diff_threshold``.

    Args:
        mag_spectrum (np.ndarray): 2D non-negative, non-mean-centered input
            with shape (dimensions, samples).
        nFeatures (int): Number of features (inner rank) NMF should retain.
        diff_threshold (float): Threshold on the change of the error norm
            between iterations below which the solution is accepted.

    Returns:
        W_NMF (np.ndarray): Feature vectors, shape (dimensions, nFeatures).
        H_NMF (np.ndarray): Weights, shape (nFeatures, samples).
        delta (list): Reconstruction error norm of the initial guess followed
            by the norm after every update (for visualizing convergence).
    """
    # Added to both denominators so that 0/0 (e.g. from an all-zero column or
    # row of the input, which drives the matching factor entries to exactly
    # zero) evaluates to 0 instead of NaN. Without it a single NaN spreads to
    # every entry of W and H and silently terminates the loop, because
    # ``nan > diff_threshold`` is False.
    eps = np.finfo(float).eps

    # Random non-negative starting point; draw order (W, then H) is kept so
    # seeded runs consume the RNG exactly as before.
    W_NMF = np.random.rand(mag_spectrum.shape[0], nFeatures)  # (dimensions, nFeatures)
    H_NMF = np.random.rand(nFeatures, mag_spectrum.shape[1])  # (nFeatures, samples)

    iterate = 0
    # delta[0] is the error of the random initial guess.
    delta = [np.linalg.norm(mag_spectrum - np.matmul(W_NMF, H_NMF))]
    diff = delta[0]

    # Iterate until the error norm stops changing by more than the threshold.
    while diff > diff_threshold:
        # H <- H * (W^T X) / (W^T W H)
        H_denominator = np.matmul(W_NMF.T, np.matmul(W_NMF, H_NMF)) + eps
        H_NMF = H_NMF * (np.matmul(W_NMF.T, mag_spectrum) / H_denominator)

        # W <- W * (X H^T) / (W H H^T), using the freshly updated H.
        W_denominator = np.matmul(W_NMF, np.matmul(H_NMF, H_NMF.T)) + eps
        W_NMF = W_NMF * (np.matmul(mag_spectrum, H_NMF.T) / W_denominator)

        # The Frobenius norm already works on magnitudes, so no abs() needed.
        error_norm = np.linalg.norm(mag_spectrum - np.matmul(W_NMF, H_NMF))

        diff = abs(error_norm - delta[-1])
        delta.append(error_norm)
        iterate += 1

    print("[NMF] NMF converged after ", iterate, " iterations")
    return W_NMF, H_NMF, delta